//////////////////////////////////////////////
// main.cpp
//
//////////////////////////////////////////////

/// Includes ---------------------------------

// Local
#include "PastCode.h"

// nkGraphics
#include <NilkinsGraphics/Buffers/Buffer.h>
#include <NilkinsGraphics/Buffers/BufferManager.h>

#include <NilkinsGraphics/Compositors/Compositor.h>
#include <NilkinsGraphics/Compositors/CompositorManager.h>
#include <NilkinsGraphics/Compositors/CompositorNode.h>
#include <NilkinsGraphics/Compositors/TargetOperations.h>

#include <NilkinsGraphics/Passes/ClearTargetsPass.h>
#include <NilkinsGraphics/Passes/PostProcessPass.h>
#include <NilkinsGraphics/Passes/RaytracingPass.h>

#include <NilkinsGraphics/RenderContexts/RenderContext.h>
#include <NilkinsGraphics/RenderContexts/RenderContextDescriptor.h>
#include <NilkinsGraphics/RenderContexts/RenderContextManager.h>

#include <NilkinsGraphics/Renderers/Renderer.h>

#include <NilkinsGraphics/System.h>

#include <thread>

/// Internals : Shaders ----------------------

void prepareRaygenMissProgram ()
{
	// Creates (or retrieves) the "raygenMiss" program, holding both the raygen and miss entry points,
	// and loads it from the HLSL source below
	nkGraphics::Program* raygenMissProgram = nkGraphics::ProgramManager::getInstance()->createOrRetrieve("raygenMiss") ;
	nkGraphics::ProgramSourcesHolder sources ;

	// This program will only use raytracing stage
	sources.setRaytracingMemory
	(
		R"eos(
			// Payload that will transit with the ray as it traverses the scene
			struct RayPayload
			{
				float4 color ;
			} ;

			cbuffer PassConstants : register(b0)
			{
				uint4 texInfos ;
				float4 camPos ;
				matrix invView ;
				matrix invProj ;
			}

			// Acceleration structure is binded to a texture slot
			RaytracingAccelerationStructure scene : register(t0) ;

			RWStructuredBuffer<float4> output : register(u0) ;

			[shader("raygeneration")]
			void raygen ()
			{
				// Keep the dispatch index as integers : going through float would lose precision
				// in the linear buffer index for large targets
				uint2 dispatchIndex = DispatchRaysIndex().xy ;

				// The dispatch size is configured independently from the target size
				// Discard invocations outside of the target so they cannot wrap into the next row
				// or index past the end of the output buffer
				if (dispatchIndex.x >= texInfos.x || dispatchIndex.y >= texInfos.y)
					return ;

				// Compute ray's origin
				float2 pixCenter = float2(dispatchIndex) + 0.5 ;
				float2 uvs = pixCenter / float2(texInfos.xy) * 2.0 - 1.0 ;
				uvs.y = -uvs.y ;
				
				float3 pixelOrigin = camPos.xyz ;
				float4 pixelDir = mul(invView, mul(invProj, float4(uvs, 0, 1))) ;
				pixelDir.xyz /= pixelDir.w ;
				float3 pixelDirVec3 = normalize(pixelDir.xyz - pixelOrigin) ;

				// Prepare the ray description
				RayDesc ray ;
				ray.Origin = pixelOrigin ;
				ray.Direction = pixelDirVec3 ;
				ray.TMin = 0.001 ;
				ray.TMax = 100.0 ;
				RayPayload payload = {float4(1, 1, 1, 1)} ;

				// Simple tracing
				TraceRay(scene, RAY_FLAG_NONE, ~0, 0, 1, 0, ray, payload) ;

				// And write result into the uav buffer, row-major with texInfos.x as the row stride
				uint index = dispatchIndex.y * texInfos.x + dispatchIndex.x ;
				
				output[index] = float4(payload.color.xyz, 1) ;
			}

			[shader("miss")]
			void miss (inout RayPayload payload)
			{
				// We will have a black pixel when missing
				payload.color = float4(0, 0, 0, 0) ;
			}
		)eos"
	) ;

	raygenMissProgram->setFromMemory(sources) ;
	raygenMissProgram->load() ;
}

void prepareRaygenMissShader ()
{
	// Prepare the shader used by the pass for raygen miss purposes
	nkGraphics::Shader* raygenMissShader = nkGraphics::ShaderManager::getInstance()->createOrRetrieve("raygenMiss") ;
	nkGraphics::Program* raygenMissProgram = nkGraphics::ProgramManager::getInstance()->get("raygenMiss") ;

	raygenMissShader->setProgram(raygenMissProgram) ;

	// Constant Buffer needs many information
	nkGraphics::ConstantBuffer* cBuffer = raygenMissShader->addConstantBuffer(0) ;

	// Align the texture sizes based on the target
	nkGraphics::ShaderPassMemorySlot* slot = cBuffer->addPassMemorySlot() ;
	slot->setAsTargetSize() ;

	slot = cBuffer->addPassMemorySlot() ;
	slot->setAsCameraPosition() ;

	slot = cBuffer->addPassMemorySlot() ;
	slot->setAsViewMatrixInv() ;

	slot = cBuffer->addPassMemorySlot() ;
	slot->setAsProjectionMatrixInv() ;

	// The scene acceleration structure is accessible through the render queue
	nkGraphics::RenderQueue* rq = nkGraphics::RenderQueueManager::getInstance()->get(nkGraphics::RenderQueueManager::DEFAULT_RENDER_QUEUE) ;
	raygenMissShader->addTexture(rq->getAccelerationStructureBuffer(), 0) ;

	nkGraphics::Buffer* buffer = nkGraphics::BufferManager::getInstance()->get("raytracedBuffer") ;
	raygenMissShader->addUavBuffer(buffer, 0) ;

	// Finalize loading
	raygenMissShader->load() ;
}

void prepareHitProgram ()
{
	nkGraphics::Program* hitProgram = nkGraphics::ProgramManager::getInstance()->createOrRetrieve("hit") ;
	nkGraphics::ProgramSourcesHolder sources ;

	// This program will only use raytracing stage
	sources.setRaytracingMemory
	(
		R"eos(
			// Payload needs to be in line with the other programs
			struct RayPayload
			{
				float4 color;
			} ;

			[shader("closesthit")]
			void closestHit (inout RayPayload payload, in BuiltInTriangleIntersectionAttributes attr)
			{
				// A hit will have a yellowish color
				payload.color = float4(0.4, 0.4, 0.3, 1.0) ;
			}
		)eos"
	) ;

	hitProgram->setFromMemory(sources) ;
	hitProgram->load() ;
}

void prepareHitShader ()
{
	// Prepare the shader used by an entity on hit
	nkGraphics::Shader* hitShader = nkGraphics::ShaderManager::getInstance()->createOrRetrieve("hit") ;
	nkGraphics::Program* hitProgram = nkGraphics::ProgramManager::getInstance()->get("hit") ;

	hitShader->setProgram(hitProgram) ;

	// Finalize loading
	hitShader->load() ;
}

/// Internals : Compositor -------------------

nkGraphics::Compositor* prepareCompositor ()
{
	// Builds the "compositor" : clear, raytrace the default render queue into the "raytracedBuffer"
	// UAV (through the raygenMiss shader), then copy that buffer onto the back buffer.
	// Expects the "raygenMiss" and "bufferCopy" shaders to have been prepared beforehand.
	// Returns the created (or retrieved) compositor.

	// Shaders the passes require
	nkGraphics::Shader* raygenMissShader = nkGraphics::ShaderManager::getInstance()->get("raygenMiss") ;
	nkGraphics::Shader* bufferCopyShader = nkGraphics::ShaderManager::getInstance()->get("bufferCopy") ;

	// Render queue holding the scene to raytrace
	nkGraphics::RenderQueue* rq = nkGraphics::RenderQueueManager::getInstance()->get(nkGraphics::RenderQueueManager::DEFAULT_RENDER_QUEUE) ;

	// Get the compositor and create a node
	nkGraphics::Compositor* compositor = nkGraphics::CompositorManager::getInstance()->createOrRetrieve("compositor") ;
	nkGraphics::CompositorNode* node = compositor->addNode() ;

	// One operation to the back buffer
	nkGraphics::TargetOperations* targetOp = node->addOperations() ;
	targetOp->setToBackBuffer(true) ;
	targetOp->setToChainDepthBuffer(true) ;

	// Unroll our passes, starting with a clear that needs no further configuration
	targetOp->addClearTargetsPass() ;

	// Here we will raytrace the scene, and for that the RaytracingPass is the way to go
	nkGraphics::RaytracingPass* rtPass = targetOp->addRaytracingPass() ;
	rtPass->setRq(rq) ;
	rtPass->setRaygenMissShader(raygenMissShader) ;

	// A raytracing pass is dispatched like a compute pass and needs an explicit invocation size.
	// It is NOT derived from the target : it is fixed to the 800x600 size of the context created in main(),
	// so both must be kept in sync if the window size changes
	const int dispatchWidth = 800 ;
	const int dispatchHeight = 600 ;
	rtPass->setWidth(dispatchWidth) ;
	rtPass->setHeight(dispatchHeight) ;

	// Eventually, what's left is to copy the resulting buffer to the target
	nkGraphics::PostProcessPass* postProcessPass = targetOp->addPostProcessPass() ;
	postProcessPass->setShader(bufferCopyShader) ;

	return compositor ;
}

/// Internals : raytracing -------------------

void prepareRaytracingInScene ()
{
	// Rendering queue needs to be flagged to be raytraced
	nkGraphics::RenderQueue* rq = nkGraphics::RenderQueueManager::getInstance()->get(nkGraphics::RenderQueueManager::DEFAULT_RENDER_QUEUE) ;
	rq->setRaytraced(true) ;

	// The entity also needs to know the shader to use on hit
	nkGraphics::Shader* hitShader = nkGraphics::ShaderManager::getInstance()->get("hit") ;
	nkGraphics::Entity* ent = rq->getEntity(0) ;
	//ent->getRenderInfo().getSlots()[0]->getLods()[0]->setRaytracingShader(hitShader) ;
}

/// Function ---------------------------------

int main ()
{
	// NOTE(review): unconditional 5 s delay at startup. Presumably left in to give time to attach a
	// graphics debugger / capture tool - confirm it is still wanted.
	std::this_thread::sleep_for(std::chrono::seconds(5)) ;

	// Prepare for logging
	// The LogManager only receives a raw pointer to the logger, so the graphics system is killed
	// on every exit path below before the logger gets destroyed
	std::unique_ptr<nkLog::Logger> logger = std::make_unique<nkLog::ConsoleLogger>() ;
	nkGraphics::LogManager::getInstance()->setReceiver(logger.get()) ;

	// For easiness
	nkResources::ResourceManager::getInstance()->setWorkingPath("Data") ;

	// Initialize and create context with window
	if (!nkGraphics::System::getInstance()->initialize())
		return -1 ;

	// Query now what the hardware is capable of
	// As we will be using raytracing, we need to be able to do so in the renderer
	nkGraphics::RendererSupportInfo supportInfo = nkGraphics::System::getInstance()->getRenderer()->getRendererSupportInfo() ;

	if (!supportInfo._supportsRaytracing)
	{
		logger->log("Current hardware does not support raytracing, this tutorial cannot be run on this machine.", "RtxTutorial") ;

		system("pause") ;

		// The system was successfully initialized above, release it before leaving
		nkGraphics::System::getInstance()->kill() ;
		return 0 ;
	}

	// Basic resource preparations
	baseInit() ;
	prepareRaytracedBuffer() ;

	// Raygen and miss shaders
	prepareRaygenMissProgram() ;
	prepareRaygenMissShader() ;

	// Hit shader
	prepareHitProgram() ;
	prepareHitShader() ;

	// Prepare raytracing setup in scene
	prepareRaytracingInScene() ;

	// Prepare the composition once everything is ready
	nkGraphics::Compositor* compositor = prepareCompositor() ;
	
	// Use the compositor for the context we create
	nkGraphics::RenderContext* context = nkGraphics::RenderContextManager::getInstance()->createRenderContext(nkGraphics::RenderContextDescriptor(800, 600, false, true)) ;

	if (!context)
	{
		logger->log("Could not create the render context.", "RtxTutorial") ;

		nkGraphics::System::getInstance()->kill() ;
		return -1 ;
	}

	context->setCompositor(compositor) ;

	// And trigger the rendering
	renderLoop(context) ;

	// Clean exit
	nkGraphics::System::getInstance()->kill() ;

	return 0 ;
}